[NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel #2555
zhongbozhu wants to merge 22 commits into NVIDIA:main
Conversation
Greptile Summary
This PR integrates a new Blackwell-native CUTLASS UMMA fusion kernel. Key changes are detailed in the updated summary below.
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[NVFP4Quantizer::quantize_impl] --> B{with_rht?}
B -- No --> C[nvte_quantize_v2\nrow-wise only]
B -- Yes --> D{eligible_for_rht_cast_fusion?\nbf16 input AND rows%64==0 AND cols%128==0}
D -- Yes Fused path --> E[nvte_quantize_with_hadamard_transform\nSingle Blackwell UMMA kernel:\n1. Row-wise quantize input\n2. Apply 16x16 RHT via GEMM\n3. Col-wise quantize + transpose]
D -- No Unfused path --> F[allocate rht_output_t buffer\ncols x rows BF16]
F --> G[quantize_with_rht_unfused_helper]
G --> H{rowwise_usage?}
G --> I{columnwise_usage?}
H -- Yes --> J[nvte_quantize_v2\nrow-wise quantize input directly]
I -- Yes --> K[nvte_hadamard_transform\ncompute RHT of transposed input]
K --> L[nvte_quantize_v2\nquantize RHT output]
E --> M[Output: rowwise FP4 + scale_inv\nAND/OR columnwise FP4 + scale_inv]
J --> M
L --> M
```
Last reviewed commit: 999fe85
Force-pushed from c80932f to fc42825
/te-ci arm L1
Force-pushed from 2bc695e to 6ea9dab
Greptile Overview
Greptile Summary
This PR integrates a Cutlass-based fusion kernel that combines row-wise quantization and column-wise RHT (Random Hadamard Transform) + quantization + transpose operations for NVFP4 dense linear layers and shared experts. The key optimization reduces memory bandwidth by reading high-precision input data once instead of twice.
Key Changes
New Fusion Kernel (row_cast_col_hadamard_transform_cast_fusion.cu):
- Implements the nvte_hadamard_transform_cast_fusion API that performs both rowwise and columnwise quantization in a single pass
- Uses MMA hardware for efficient Hadamard transform computation
- Eligible when the input is BF16 with rows divisible by 64 and columns divisible by 128
- Reads pre-computed amax values to calculate FP8 scaling factors
- Supports stochastic rounding and fast math optimization flags
Refactored Quantizer Logic (quantizer.cpp):
- Moved the unfused RHT path into a quantize_with_rht_unfused_helper method for cleaner code organization
- Improved RNG state handling: a single RNG state when fusion is used, separate states for rowwise/columnwise when unfused
- Added NVTE_USE_FAST_MATH environment variable support for accelerating high-precision math operations
- Moved the eligibility check before RNG state generation to avoid unnecessary work
Extended Test Coverage (test_nvfp4_rht_quantize_exact.py):
- Added "columnwise-only" quantization mode testing alongside existing "quantize" and "quantize_transpose" modes
- Tests now validate rowwise/columnwise results conditionally based on the quantization mode
Grouped Quantization Support (cast.cpp):
- Split-quantize path now uses fused kernel when all tensors have 128-aligned dimensions
- Bulk RNG state generation for grouped kernels (single state shared across splits)
- Fast math flag propagation to all quantization configs
Architecture Notes
The fusion provides optimal performance when:
- Input dtype is BF16
- Rows are divisible by 64 (MMA tile requirement)
- Columns are divisible by 128 (MMA tile requirement)
When these conditions aren't met, the code gracefully falls back to the unfused path with separate kernel launches for rowwise and columnwise quantization.
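A minimal sketch of that eligibility predicate, assuming illustrative names (the real check appears in quantizer.cpp as the eligible_for_rht_cast_fusion flag quoted in the review threads below):

```cpp
#include <cstddef>

enum class DType { kBFloat16, kFloat16, kFloat32 };  // illustrative subset

bool eligible_for_rht_cast_fusion(DType input_dtype, std::size_t rows, std::size_t cols) {
  // MMA-hardware restrictions: BF16 input, rows % 64 == 0, cols % 128 == 0.
  return input_dtype == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
}
```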
Confidence Score: 4/5
- This PR is safe to merge with minimal risk after addressing documentation and TODO items mentioned in the PR description
- Score of 4 reflects a well-engineered feature with thorough implementation. The code demonstrates good software practices: clean refactoring with extracted helper methods, proper error handling, graceful fallback paths, and comprehensive test coverage including the new columnwise-only mode. The fusion kernel follows established patterns from the grouped quantization PR #2411. Deducted 1 point due to: (1) PR author notes cutlass deprecation warnings need addressing, (2) TODOs remain about potentially defaulting fast math on, and (3) the ~1400 line CUDA kernel file has limited inline documentation for complex template logic
- The main CUDA kernel file (row_cast_col_hadamard_transform_cast_fusion.cu) would benefit from additional inline comments explaining the template parameter switches and MMA computation flow, but no files have critical issues requiring immediate attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/csrc/quantizer.cpp | 4/5 | Refactored NVFP4 quantize_impl to use new fused RHT cast kernel, extracted unfused helper, improved RNG state handling for fused vs unfused paths |
| transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu | 4/5 | New CUDA kernel implementing fused row-cast and column-RHT-transpose-cast using Cutlass MMA hardware for BF16 inputs with 64x128 alignment |
| transformer_engine/common/include/transformer_engine/hadamard_transform.h | 5/5 | Added new API function nvte_hadamard_transform_cast_fusion for dense layer quantization, marked old columnwise function for future deprecation |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Added NVTE_USE_FAST_MATH env var support in split_quantize for grouped NVFP4 kernels, improved RNG state setup with bulk generation flag |
| tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py | 5/5 | Extended test coverage to support columnwise-only quantization mode, added return_identity parameter to test all three modes |
Sequence Diagram
```mermaid
sequenceDiagram
participant User as User Code
participant Quantizer as NVFP4Quantizer
participant API as nvte_hadamard_transform_cast_fusion
participant Kernel as row_col_rht_gemm_ntt_w_sfc
participant AmaxKernel as nvte_hadamard_transform_amax
User->>Quantizer: quantize(input, output)
Quantizer->>Quantizer: Check eligibility (BF16, rows%64==0, cols%128==0)
alt With RHT and eligible for fusion
Quantizer->>AmaxKernel: Compute rowwise & columnwise amax
AmaxKernel-->>Quantizer: amax values populated
alt Stochastic rounding enabled
Quantizer->>Quantizer: Generate RNG state
end
alt Fast math enabled (NVTE_USE_FAST_MATH)
Quantizer->>Quantizer: Set use_fast_math flag
end
Quantizer->>API: Call with input, output, hadamard_matrix, quant_config
API->>Kernel: Launch fused kernel
Kernel->>Kernel: Read amax values
Kernel->>Kernel: Perform rowwise quantization to FP4
Kernel->>Kernel: Compute RHT using MMA hardware
Kernel->>Kernel: Transpose and quantize to FP4
Kernel->>Kernel: Write FP8 scales
Kernel-->>API: Complete
API-->>Quantizer: Return
else Not eligible for fusion
Quantizer->>AmaxKernel: Compute amax
AmaxKernel-->>Quantizer: amax values
alt Rowwise usage
Quantizer->>Quantizer: Call nvte_quantize_v2 for rowwise
end
alt Columnwise usage
Quantizer->>Quantizer: Call nvte_hadamard_transform for RHT
Quantizer->>Quantizer: Call nvte_quantize_v2 for columnwise
end
end
Quantizer-->>User: Quantized output
```
Force-pushed from 011169e to 39f9272
```cpp
// 1. Rowwise quantization
// 2. RHT followed by columnwise quantization & transpose
NVTE_SCOPED_GIL_RELEASE({
  nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(),
                                      quant_config, stream);
});
```
consider documenting the performance impact of NVTE_USE_FAST_MATH
Since the PR description strongly recommends enabling fast math for significant performance improvement, consider adding a comment here explaining the expected performance gain and why it's recommended for production use (currently only noted in the PR description).
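One possible shape for such a comment, drawing only on what the PR description already states (wording is illustrative, not the author's):

```cpp
// NVTE_USE_FAST_MATH relaxes only the high-precision (FP32) helper math here
// (e.g. division becomes multiply-by-approximate-reciprocal), so the numerical
// impact is minimal, but it lets the FP4 quantization cost stay hidden under
// the RHT GEMM; the PR author strongly recommends enabling it in production.
```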
```cpp
static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0;
static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0;
static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0;
static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0;
static int constexpr NumSchedThreads = 32;
static int constexpr NumMainloopLoadThreads = 32;
static int constexpr NumEpilogueThreads =
    NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount;
```
consider adding rationale for thread counts
Adding a brief comment explaining why these specific thread counts (32 MMA, 128 col quant, 256 row quant) were chosen would help future maintainers understand the workload distribution design.
```cpp
// Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
bool eligible_for_rht_cast_fusion =
```
consider documenting cols % 128 requirement
While confirmed intentional in previous threads, adding a comment explaining why the dense kernel requires cols % 128 (likely UMMA tile alignment) would prevent future confusion, especially since MOE uses different alignment.
Force-pushed from 39f9272 to fd29f6b
/te-ci arm L1
Force-pushed from fd29f6b to 0864f99
@ptrendx can we merge this PR soon?
```cpp
// and inconsistently implemented.
// What math is accelerated? Only the high precision math, so numerical impact is minimal
// 1. replace x / y by x * (1/y)
// 2. replace 1 / x by reciprocal_approximate_ftz(x)
```
Point 2 is scary. I've heard that in the past there were issues with the FTZ setting for numerics.
Yeah. However, this part of the FP32 math is pretty important for the FP4 quantization cost to be hidden under the RHT GEMM. Users are expected to know the danger and enable it explicitly with NVTE_USE_FAST_MATH.
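For readers unfamiliar with the FTZ concern, a self-contained sketch of the two rewrites under discussion; the approximate reciprocal maps to the rcp.approx.ftz.f32 PTX instruction, and this is illustrative code, not the kernel's:

```cuda
// Illustrative CUDA sketch of the NVTE_USE_FAST_MATH rewrites (not TE code).
__device__ float reciprocal_approximate_ftz(float x) {
  float y;
  // Approximate reciprocal; flushes denormal inputs/outputs to zero (FTZ),
  // which is the behavior the reviewer flags as historically risky.
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

__device__ float fast_div(float x, float y) {
  // rewrite 1: x / y  ->  x * (1 / y)
  // rewrite 2: 1 / y  ->  reciprocal_approximate_ftz(y)
  return x * reciprocal_approximate_ftz(y);
}
```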
```cpp
int k_tile_size = 1024;

// TODO: add support for swizzle sf output
const bool use_swizzle_sf_output = false;
```
We should then check whether the output tensor requested swizzling and fail if so.
This one is more complicated, because the requirement for triggering this fusion is that M is a multiple of 64 instead of 128, while the swizzle fusion does require a multiple of 128. For MoE we can zero-pad, but we cannot easily zero-pad for dense. Unless we lift the 64-multiple requirement for triggering the RHT fusion once and for all, we would need extra dispatch logic to decide, when this fusion is used, whether to also trigger the swizzle fusion.
Sure, that is something we can solve in a follow-up, but for now we should at least fail explicitly if the output tensor has that option set, to make sure there is no silent wrong answer.
```cpp
// 2. RHT is enabled
// 3. Columnwise usage is enabled
// 4. Rowwise and columnwise quantization are not fused,
//    because within a single kernel we can generate two different random numbers for rowwise and columnwise
```
Did we make sure that when we request the RNG state we tell it we will generate twice as many numbers? I don't see any changes to that in the diff.
The rng per-thread has been updated to be num_tensor * 1024.
And how many numbers do we actually need here?
I shouldn't have said num_tensor * 1024, since num_tensor==1.
Within the kernel, which is a persistent kernel, RNG generation already uses the current data-tile index, not the CTA index, so the RNG count here only needs to account for one CTA: https://github.com/zhongbozhu/TransformerEngine/blob/4f24746b81f011c826dca290f5cc9eb5732e4cf3/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu#L962
For this kernel, there are 4 warps handling colwise quantization and 8 warps handling rowwise quantization. One tile processed at a time is 128x128 elements. So in the worst case for colwise quantize, that's 128 * 128 / 4 / 32 = 128 elements touched per thread; 1024 is well above that.
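The arithmetic above, restated as a compile-time check (constants are taken from this thread; the names are illustrative, not the kernel's):

```cpp
constexpr int kTileRows = 128, kTileCols = 128;  // one data tile per iteration
constexpr int kColQuantWarps = 4;                // warps on colwise quantization
constexpr int kThreadsPerWarp = 32;
constexpr int kWorstCaseElemsPerThread =
    (kTileRows * kTileCols) / (kColQuantWarps * kThreadsPerWarp);  // = 128
static_assert(kWorstCaseElemsPerThread <= 1024,
              "per-thread RNG budget of 1024 covers the worst case");
```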
Force-pushed from 4f24746 to e85d6de
Force-pushed from e85d6de to 3a93a72
Additional Comments (1)
In the refactored unfused path, the rht_output_t buffer is allocated even when columnwise quantization is not needed. For large models (e.g. hidden_dim=4096, seq_len=2048, BF16) this wastes ~16MB of GPU memory per call and incurs unnecessary allocation overhead. Suggested fix:

```cpp
} else {
  auto& columnwise_quant_config_to_use =
      need_separate_columnwise_rng ? quant_config_columnwise : quant_config;
  // Only allocate the RHT transposed buffer if columnwise quantization is needed
  at::Tensor rht_output_t;
  TensorWrapper rht_output_t_cpp;
  if (columnwise_usage) {
    rht_output_t =
        allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
    rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                      std::vector<size_t>{cols, rows});
  }
  this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config,
                                         columnwise_quant_config_to_use, stream);
}
```
/te-ci arm L1
Additional Comments (3)
Inside the unfused path, the RHT transpose buffer is allocated even when only rowwise quantization is requested. Consider guarding the allocation:

```cpp
at::Tensor rht_output_t;
TensorWrapper rht_output_t_cpp;
if (columnwise_usage) {
  rht_output_t =
      allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
  rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                    std::vector<size_t>{cols, rows});
}
this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config,
                                       columnwise_quant_config_to_use, stream);
```
The fused kernel carries significant overhead (TMA pipeline infrastructure, cluster-launch-control scheduler, 512-thread CTAs) when only rowwise quantization is required. For inference workloads that only request rowwise output, dispatching the heavier GEMM kernel is slower than the simpler nvte_quantize_v2 path. Consider short-circuiting:

```cpp
if (eligible_for_rht_cast_fusion && columnwise_usage) {
  // fused path: rowwise + RHT columnwise in one kernel
  NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
             "RHT matrix is not available.");
  auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
  NVTE_SCOPED_GIL_RELEASE({
    nvte_quantize_with_hadamard_transform(input.data(), out.data(), rht_matrix_nvte.data(),
                                          quant_config, stream);
  });
} else if (eligible_for_rht_cast_fusion && !columnwise_usage) {
  // rowwise only: the simpler kernel is sufficient
  NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
} else {
  // unfused fallback path
  ...
}
```
Since the PR notes that fast-math is "strongly recommended for performance", consider caching this value at construction time in the quantizer instead of re-reading the environment variable on every call.
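A sketch of what such caching could look like; use_fast_math_cached is a hypothetical helper, and TE's actual env-var utilities may differ:

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical: resolve NVTE_USE_FAST_MATH once per process instead of per call.
inline bool use_fast_math_cached() {
  static const bool value = [] {
    const char* env = std::getenv("NVTE_USE_FAST_MATH");
    return env != nullptr && std::strcmp(env, "0") != 0;
  }();
  return value;
}
```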
Additional Comments (1)
In the unfused path, the RHT transpose buffer is again allocated regardless of columnwise usage. Before this refactoring the allocation was correctly guarded inside the columnwise branch.
```cpp
#include <cuda_runtime.h>

#if CUDA_VERSION >= 12080
#include "common/common.h"
```
This is for having FP4_TYPE_SUPPORTED defined in this header file, which is later referenced in ptx.cuh.
I expect this fix to solve some compile issues for older CUDA versions.
Description
Note: #2558 reported a bug in #2411. The fix is in #2564; make sure you cherry-pick that one too until it lands in main.
Previously, a similar optimization was applied to MoE grouped quantize with RHT in #2411. This PR targets dense linear layers and shared experts being quantized to NVFP4. With this fusion the high-precision input only needs to be read once, whereas without it the input is read twice (once for rowwise quantization and once for the RHT columnwise pass).
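To make the bandwidth saving concrete, a back-of-the-envelope sketch (the 2048 x 4096 BF16 shape is borrowed from a review comment above; purely illustrative):

```cpp
#include <cstddef>

// Read-traffic comparison for one 2048 x 4096 BF16 activation.
constexpr std::size_t rows = 2048, cols = 4096, bytes_per_elem = 2;  // BF16
constexpr std::size_t input_bytes = rows * cols * bytes_per_elem;    // 16 MiB
constexpr std::size_t unfused_reads = 2 * input_bytes;  // rowwise pass + RHT colwise pass
constexpr std::size_t fused_reads = input_bytes;        // single fused kernel
static_assert(2 * fused_reads == unfused_reads, "fusion halves high-precision input reads");
```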
Similarly, the env var NVTE_USE_FAST_MATH controls the numerical behavior of the RHT quant fusion kernel to accelerate it further. Fast math is only applied to the high-precision math, so it has minimal impact on training convergence.
What the fast-math toggle controls: (1) replace x / y by x * (1/y); (2) replace 1 / x by an approximate, flush-to-zero reciprocal.
Therefore, I DO recommend turning it on, since it significantly improves RHT kernel performance.
The only reason it is still not on by default is that we want ZERO TOLERANCE tests between our CUDA quantize kernels and our PyTorch-based emulated quantize references. With the fast-math toggle turned on, it is hard to pass the tests with zero tolerance without further investigation into how to relax the test conditions while still providing high confidence in the test case.
TODO items:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: